
Conversation

@ved1beta (Contributor) commented Aug 21, 2025

Description

Feature 1: the merge-lora script no longer loads the full model into memory at all. It iterates over each `.bin` or `.safetensors` shard and applies the LoRA update to each module as needed, which is far more memory-efficient than the standard approach; a minimal sketch of the idea follows below.

• New file `lora_merge_efficient` with the core implementation
• New parameter `merge_method`: `standard` / `memory_efficient`
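
A minimal sketch of the shard-wise idea, assuming HF-style safetensors shards and PEFT-style `lora_A`/`lora_B` key names (the helper name and key prefix are illustrative; the actual implementation is in `src/axolotl/utils/lora_merge_efficient.py`):

```python
# Hypothetical sketch only: merge LoRA deltas into one shard at a time,
# never materializing the whole model. Key prefixes/names are assumptions.
from pathlib import Path

import torch
from safetensors import safe_open
from safetensors.torch import save_file


def merge_shard(shard_path: Path, lora_state: dict, scale: float, out_dir: Path) -> None:
    merged = {}
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensor = f.get_tensor(key)
            base = key.removesuffix(".weight")
            # Assumed PEFT-style adapter key layout.
            lora_a = lora_state.get(f"base_model.model.{base}.lora_A.weight")
            lora_b = lora_state.get(f"base_model.model.{base}.lora_B.weight")
            if lora_a is not None and lora_b is not None:
                # W' = W + scale * (B @ A), computed in fp32 then cast back.
                delta = scale * (lora_b.float() @ lora_a.float())
                merged[key] = (tensor.float() + delta).to(tensor.dtype)
            else:
                merged[key] = tensor
    save_file(merged, str(out_dir / shard_path.name))
```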

Motivation and Context

#1679

references

qlora-pipe/tools/merge_lora.py

Tests

Tested with examples/llama-3/qlora-1b.yml using a tiny Llama 1B Instruct model and `merge_method: memory_efficient`.

Summary by CodeRabbit

  • New Features

    • Adds a memory-efficient LoRA merge that processes model shards without loading the full model; includes a legacy in-memory merge fallback when needed.
  • Chores

    • Configurable merge method (default: memory_efficient), improved logging (method choice and per-shard progress), clearer CLI messaging, and safer merged output handling.
  • Documentation

    • Updated config schema and docstrings to describe both merge strategies; public API unchanged.

coderabbitai bot (Contributor) commented Aug 21, 2025

Important: Review skipped. Auto incremental reviews are disabled on this repository.

📝 Walkthrough

Adds a shard-wise, memory-efficient LoRA merging utility and integrates it into the CLI with a dispatch that prefers the memory-efficient method (default) and falls back to the legacy in-memory merge on RuntimeError; also adds a merge_method PEFT config field defaulting to "memory_efficient".
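
A simplified sketch of the dispatch shape described above (the function and parameter names here are illustrative, not the PR's exact API):

```python
import logging

LOG = logging.getLogger(__name__)


def dispatch_merge(cfg, efficient_fn, legacy_fn):
    """Illustrative dispatch: prefer the memory-efficient merge, fall back to legacy on RuntimeError."""
    method = getattr(cfg, "merge_method", "memory_efficient")
    if method == "legacy":
        legacy_fn(cfg)
        return
    try:
        efficient_fn(cfg)
    except RuntimeError as err:
        LOG.warning("Memory-efficient merge failed (%s); falling back to the legacy merge.", err)
        legacy_fn(cfg)
```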

Changes

| Cohort / File(s) | Change Summary |
| --- | --- |
| CLI merge dispatch & helpers<br>`src/axolotl/cli/merge_lora.py` | Add `merge_method` handling (default `"memory_efficient"`); log chosen method; import `merge_lora_sharded_efficient`; implement `_do_merge_lora_legacy` (in-memory) and `_do_merge_lora_efficient` (shard-wise); update `do_merge_lora` to dispatch with a `RuntimeError` fallback to legacy; adjust CLI messages and docstring. |
| Memory-efficient LoRA merge utility<br>`src/axolotl/utils/lora_merge_efficient.py` | New module implementing `get_model_shards`, `find_lora_weights`, `copy_non_model_files`, and `merge_lora_sharded_efficient`. Supports `.safetensors` and `.bin` shards, reads the adapter config to compute the scale, applies per-shard LoRA deltas without loading the full model, preserves safetensors metadata when possible, copies non-model files, and performs per-shard memory cleanup and logging. |
| PEFT schema update<br>`src/axolotl/utils/schemas/peft.py` | Add a `merge_method: Literal["legacy", "memory_efficient"]` config field, defaulting to `"memory_efficient"`. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


@ved1beta ved1beta marked this pull request as ready for review August 22, 2025 11:42
codecov bot commented Aug 22, 2025

Codecov Report

❌ Patch coverage is 14.81481% with 138 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/axolotl/cli/utils/lora_merge.py | 12.58% | 125 Missing ⚠️ |
| src/axolotl/cli/merge_lora.py | 23.52% | 13 Missing ⚠️ |


coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (12)
src/axolotl/utils/schemas/peft.py (2)

132-138: Enum looks good; consider documenting/deprecating the old boolean to avoid confusion

Introducing merge_method: Literal["standard", "memory_efficient"] is a clear improvement. However, merge_lora: bool | None still exists and is referenced elsewhere, which can be confusing to users. Recommend:

  • Mark merge_lora as deprecated in the field’s description and CLI help.
  • In docs/examples, prefer merge_method and explain when to set merge_lora=True (still used by validators/CLI).

154-177: Validator still keys off merge_lora; clarify interaction with merge_method

validate_qlora gates behavior on self.merge_lora. If a user configures only merge_method="memory_efficient" but forgets merge_lora=True, validation may incorrectly follow the "not merging" path. Since do_cli() currently forces merge_lora=True, this is fine for the CLI path, but configs used programmatically may drift.

Option A (minimal): Document that merge_lora must be True when invoking a merge, regardless of merge_method.

Option B (preferred): Make the validator robust by checking an explicit “are we merging” signal that can be derived from the CLI invocation or merge_method when the merge entrypoint is called. If changing semantics is risky, emit a warning when merge_method != "standard" and merge_lora is not True.

src/axolotl/utils/lora_merge_efficient.py (7)

60-87: Skip index files and preserve metadata when copying; avoid needless large-file copies

  • If get_model_shards missed some shard patterns, copy_non_model_files can end up copying large shard files only to overwrite them later.
  • Also skip index JSONs (*.index.json) and use copy2 to preserve timestamps and metadata.
@@
-    for filepath in input_path.glob("*"):
+    for filepath in input_path.glob("*"):
         if filepath.is_dir():
             continue
         if filepath.name in shard_names:
             continue
+        # Skip HF index files and other model indices
+        if filepath.name.endswith(".index.json"):
+            continue
         if filepath.suffix == ".gguf":
             continue
         if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
             continue
@@
-        shutil.copy(filepath, output_path)
+        shutil.copy2(filepath, output_path)

100-109: Graceful fallback when CUDA is unavailable

Defaulting to "cuda" can fail on CPU-only hosts. Fall back to CPU and log once.

@@
-    base_model_path = Path(base_model_path)
+    base_model_path = Path(base_model_path)
@@
-    output_path = Path(output_path)
+    output_path = Path(output_path)
+
+    if device == "cuda" and not torch.cuda.is_available():
+        LOG.warning("CUDA not available; falling back to CPU for merge.")
+        device = "cpu"

115-119: Guard against missing/zero rank in adapter config

Division by zero or None will raise without a helpful message.

-    scale = lora_config.lora_alpha / lora_config.r
+    if not getattr(lora_config, "r", None):
+        raise ValueError("Invalid LoRA config: rank `r` is missing or zero.")
+    scale = lora_config.lora_alpha / lora_config.r

130-134: Compat: torch.load(weights_only=True) isn’t available on older Torch

Provide a fallback for wider compatibility.

-    else:
-        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+    else:
+        try:
+            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+        except TypeError:
+            # Fallback for older torch versions
+            lora_state = torch.load(lora_file, map_location="cpu")

154-181: Reduce GPU memory pressure; ensure CPU tensors before saving; handle fan_in_fan_out

  • Reading tensors straight onto CUDA and accumulating them in merged_tensors keeps a whole shard on GPU; flip results back to CPU eagerly.
  • Apply fan_in_fan_out if present to avoid incorrect orientation on models that require it.
@@
-                for key in f.keys():
+                for key in f.keys():
                     total_tensors += 1
                     tensor = f.get_tensor(key)
                     lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                        delta = scale * (
-                            lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                        )
-
-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
+                        delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+                        # Handle fan_in_fan_out if present
+                        if getattr(lora_config, "fan_in_fan_out", False):
+                            delta = delta.T
+                        merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).cpu()
                     else:
-                        merged_tensors[key] = tensor
+                        merged_tensors[key] = tensor.to("cpu")

182-199: Mirror CPU-offload approach for .bin shards

Same memory considerations should apply when shards are .bin.

@@
-            for key, tensor in state_dict.items():
+            for key, tensor in state_dict.items():
                 total_tensors += 1
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                    delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+                    if getattr(lora_config, "fan_in_fan_out", False):
+                        delta = delta.T
+                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.to("cpu")

135-139: Optional: Avoid moving entire LoRA state to GPU upfront

Moving the whole LoRA state to CUDA may spike memory on large adapters. Consider keeping LoRA on CPU and moving just the A/B tensors for the tensor being merged in the loop (or using pinned memory). This is a targeted optimization and can be a follow-up.

src/axolotl/cli/merge_lora.py (3)

24-29: Validate merge_method early and log the choice

Guard against typos and make behavior explicit.

-    merge_method = getattr(cfg, "merge_method", "standard")
+    merge_method = getattr(cfg, "merge_method", "standard")
+    if merge_method not in ("standard", "memory_efficient"):
+        raise ValueError(f"Invalid merge_method: {merge_method!r}. Expected 'standard' or 'memory_efficient'.")
+    LOG.info("Selected LoRA merge method: %s", merge_method)

88-91: Nit: docstring typo

Use load_in_4bit (with underscore) for consistency with the actual config key.

-    (`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
+    (`load_in_8bit=False`, `load_in_4bit=False`, `flash_attention=False`, etc.).

113-118: Support remote LoRA adapters (HF Hub) or improve error message

The memory-efficient implementation currently requires a local adapter directory. Consider:

  • Supporting Hub IDs by calling snapshot_download (similar to base model logic), or
  • Clarifying the message to “LoRA adapter directory does not exist” to avoid implying the merged output should already exist.
-        raise ValueError(
-            f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`"
-        )
+        raise ValueError(
+            f"LoRA adapter directory does not exist: `{parsed_cfg.lora_model_dir}`. "
+            "Provide a local path to the adapter (directory containing adapter_config.json)."
+        )
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 0fa752e and d0c0116.

📒 Files selected for processing (3)
  • src/axolotl/cli/merge_lora.py (3 hunks)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
  • src/axolotl/utils/schemas/peft.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (89-214)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

159-159: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Comment on lines 24 to 28
merge_method = getattr(cfg, "merge_method", "standard")
if merge_method == "memory_efficient":
    _do_merge_lora_efficient(cfg=cfg)
else:
    _do_merge_lora_standard(cfg=cfg)
Collaborator:

let's be opinionated and prefer the memory_efficient merge. the old "standard" version can be a manual fallback if users have issues. You could also do something like wrap the _do_merge_lora_efficient in a try/catch and if it fails, provide a hint to the user to use the standard option

Collaborator:

maybe rename standard to legacy or something else since it's no longer axolotl's "standard"

Collaborator:

One thing to consider is that this only works for standard LoRA, and not other advanced methods like DoRA

@djsaunde (Member):

curious if you have any numbers on how much peak VRAM is saved?
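
(For context, peak VRAM during a merge run can be measured with torch's CUDA memory counters; a minimal sketch, assuming a CUDA device and a `run_merge` callable standing in for either merge path:)

```python
import torch


def measure_peak_vram(run_merge) -> float:
    """Return peak CUDA memory (GiB) allocated while run_merge() executes. Assumes a CUDA device."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_merge()
    return torch.cuda.max_memory_allocated() / 1024**3
```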

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (3)
src/axolotl/utils/lora_merge_efficient.py (3)

22-43: Fix key parsing and matching: strip(".weight") is incorrect; substring matching risks false positives

  • strip(".weight") removes any of the characters in ".weight" from both ends, corrupting keys.
  • Matching with if clean_key in lora_key can attach the wrong adapters when names share substrings.

Apply an exact-suffix removal and precise endswith-based matching, and bail early once both A/B are found:

@@
-def find_lora_weights(
-    lora_state: Dict[str, torch.Tensor], key: str
-) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+def find_lora_weights(
+    lora_state: Dict[str, torch.Tensor], key: str
+) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
@@
-    clean_key = key.strip(".weight")
+    # Remove only the exact ".weight" suffix if present
+    clean_key = key[:-7] if key.endswith(".weight") else key
     clean_key = re.sub(r"^(base_model\.model\.|language_model\.)", "", clean_key)
@@
-    for lora_key, lora_weight in lora_state.items():
-        if clean_key in lora_key:
-            if "lora_A" in lora_key:
-                lora_a = lora_weight
-            elif "lora_B" in lora_key:
-                lora_b = lora_weight
+    suffixes_a = (f"{clean_key}.lora_A.weight", f"{clean_key}.lora_A.default.weight")
+    suffixes_b = (f"{clean_key}.lora_B.weight", f"{clean_key}.lora_B.default.weight")
+    for lora_key, lora_weight in lora_state.items():
+        if any(lora_key.endswith(sfx) for sfx in suffixes_a):
+            lora_a = lora_weight
+        elif any(lora_key.endswith(sfx) for sfx in suffixes_b):
+            lora_b = lora_weight
+        if lora_a is not None and lora_b is not None:
+            break

200-209: Do not rename .bin shards to .safetensors; always save CPU tensors; preserve original format

Renaming .bin to .safetensors while using torch.save produces invalid files and breaks index JSONs. Also, both safetensors and torch.save should receive CPU tensors.

Apply this fix:

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            # Preserve .bin format to keep HF index consistency
+            torch.save(merged_tensors, output_shard_path)

46-58: Runtime error: `list[Path]()` is not a constructor; leave patterns as HF-standard

list[Path]() will raise at runtime. Initialize with a literal/list() instead. Patterns look good per transformers conventions (pytorch_model*.bin, model*.safetensors).

 def get_model_shards(model_path: Path) -> list[Path]:
     """Find all model shards in the given path."""
-    shards = list[Path]()
+    shards: list[Path] = []
@@
-    patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"]
+    patterns = ["model*.safetensors", "model*.bin", "pytorch_model*.bin"]
🧹 Nitpick comments (7)
src/axolotl/utils/lora_merge_efficient.py (7)

120-134: Defensive load for older Torch versions (weights_only not available everywhere)

torch.load(..., weights_only=True) is not present in all Torch versions and may raise TypeError.

Wrap with a compatibility fallback:

-    if lora_file.suffix == ".safetensors":
-        lora_state = safetensors.torch.load_file(lora_file)
-    else:
-        lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+    if lora_file.suffix == ".safetensors":
+        lora_state = safetensors.torch.load_file(lora_file)
+    else:
+        try:
+            lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)
+        except TypeError:
+            # Torch < 2.3 compatibility
+            lora_state = torch.load(lora_file, map_location="cpu")

If CI uses an older Torch, this prevents a hard failure.


135-139: Optional: avoid transferring LoRA weights to GPU unless necessary

Moving the entire LoRA state to GPU can be costly and unnecessary if shards are read on CPU for save. If GPU memory is tight, keep LoRA on CPU and cast selectively when computing deltas.

-    if device != "cpu":
+    if device != "cpu":
         LOG.info(f"Moving LoRA weights to {device}")
         for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
             lora_state[key] = value.to(device)

Consider gating this by a config flag (e.g., lora_on_cpu for this path defaulting to True) or only moving lora_a/lora_b on demand inside the merge loop.


182-199: torch.load on .bin shards: keep tensors on CPU if you intend to save CPU tensors

You map to device, which could be CUDA, but safetensors expects CPU tensors on save. You address this later—see next comment. For clarity and lower VRAM pressure, consider map_location="cpu" here and move per-tensor only for math if you truly need GPU.

-            state_dict = torch.load(
-                shard_path, map_location=device
-            )  # nosec B614: loading trusted model weights
+            state_dict = torch.load(
+                shard_path, map_location="cpu"
+            )  # nosec B614: loading trusted model weights

211-213: Guard CUDA-only cache clearing to avoid errors on non-CUDA devices

Calling torch.cuda.empty_cache() when CUDA is unavailable or when device is not CUDA can error on some setups.

-        if device != "cpu":
-            torch.cuda.empty_cache()
+        if isinstance(device, str) and device.startswith("cuda") and torch.cuda.is_available():
+            torch.cuda.empty_cache()

111-118: Verify LoraConfig loader API; provide fallback for environments without from_json_file

Some PEFT versions don’t expose LoraConfig.from_json_file. If that’s the case in your CI, parse JSON and construct LoraConfig directly, or use LoraConfig.from_pretrained(lora_adapter_path).

-    lora_config = LoraConfig.from_json_file(config_file)
+    try:
+        lora_config = LoraConfig.from_json_file(config_file)  # type: ignore[attr-defined]
+    except AttributeError:
+        import json
+        with open(config_file) as f:
+            cfg = json.load(f)
+        lora_config = LoraConfig(**cfg)

150-178: Performance: avoid O(N_params × N_lora_tensors) scanning by pre-indexing LoRA

Current approach scans all LoRA entries for every model tensor. Pre-index once into a dict of base_key → (A, B) to reduce time on large models.

I can provide a follow-up patch to build an index like:

  • Parse lora_state keys, normalize (strip prefixes), map base_key → (A, B).
  • In the loop, just lookup = lora_index.get(clean_key).

Want me to draft this refactor?


96-109: Remote path detection is heuristic; ok for now

The slash check is pragmatic. If you later add support for local paths with slashes that don’t exist yet, consider a more explicit “is HF repo id” flag or try/except around snapshot_download.

📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 95f224c and 4ccf5ae.

📒 Files selected for processing (1)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings
📚 Learning (from winglian, PR axolotl-ai-cloud/axolotl#3095, src/axolotl/cli/merge_lora.py:65-81, 2025-08-22T13:23:41.434Z): The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (1)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

159-159: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (5)
src/axolotl/utils/lora_merge_efficient.py (5)

150-181: Numerical safety: ensure delta and base cast paths are consistent

You correctly compute in fp32 then cast back. Consider also handling mismatched device/dtype gracefully (e.g., bf16 base). Current code already promotes to float32, which is fine. No action required—just noting it’s correct.


60-87: Copy rules look sane; confirm index files are preserved

The logic preserves pytorch_model*.bin.index.json and model*.safetensors.index.json, and skips GGUF and shard payloads. This is what we want for HF repos.

Please double-check that repos with non-standard component subfolders (e.g., text_encoder/, vision_tower/) are handled upstream by selecting the correct base path; this helper only handles top-level files.


140-146: Good: shard discovery + non-model copy before merge

Flow is sound: discover shards → copy non-model artifacts → iterate shards. Nice.


118-118: Nice: logging scale factor aids reproducibility

Good observability touch.


214-215: Good final summary log

Simple, actionable merge summary.

@ved1beta (Contributor, Author):

benchmarks coming soon

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
src/axolotl/utils/lora_merge_efficient.py (3)

27-37: Fix suffix removal: rstrip(".weight") is incorrect and can corrupt keys

rstrip removes any trailing characters present in the set {'.', 'w', 'e', 'i', 'g', 'h', 't'} rather than the exact suffix, which can break LoRA key matching and lead to silent mis-merges or missed merges.

Apply this diff to remove only the exact ".weight" suffix:

-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key

43-55: Initialize list correctly; current list[Path]() raises at runtime

list[Path]() is a subscripted type hint, not a constructor. This will raise a TypeError at runtime.

Use a standard list initialization with an optional type hint:

-def get_model_shards(model_path: Path) -> list[Path]:
+def get_model_shards(model_path: Path) -> list[Path]:
     """Find all model shards in the given path."""
-    shards = list[Path]()
+    shards: list[Path] = []
 
     patterns = ["model*.safetensors", "pytorch_model*.bin"]

197-206: Preserve original shard format; don’t emit fake “.safetensors” from torch.save; ensure CPU tensors before saving

Renaming .bin to .safetensors while still using torch.save produces invalid safetensors files and breaks HF index mappings. Also, ensure tensors are on CPU when writing.

Apply this diff:

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but base shards are .bin; preserving .bin format to keep index consistent."
+                )
+            torch.save(merged_tensors, output_shard_path)
src/axolotl/cli/merge_lora.py (1)

80-91: Pass device explicitly to support CPU-only hosts; avoid defaulting to CUDA

The helper defaults to "cuda". On CPU-only machines this will fail before merging, forcing an unnecessary fallback to legacy. Detect availability and pass an explicit device.

Apply this diff:

     LOG.info("Using memory-efficient LoRA merging method...")
 
     output_path = Path(cfg.output_dir) / "merged"
     safe_tensors = getattr(cfg, "save_safetensors", True)
 
+    # Select device: prefer CUDA if available, otherwise CPU
+    try:
+        import torch
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    except Exception:  # pragma: no cover
+        device = "cpu"
+
     # Perform memory-efficient merge
     merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
-        safe_tensors=safe_tensors,
+        safe_tensors=safe_tensors,
+        device=device,
     )

Note: As per the team learning, we’re not tying this to lora_on_cpu (that flag is only relevant for full-model loading), just honoring hardware availability.

🧹 Nitpick comments (3)
src/axolotl/utils/lora_merge_efficient.py (2)

112-116: Guard against invalid LoRA configs (r == 0) and log context

Defensive check: if lora_config.r is 0, division fails and merge scale is undefined.

Apply this small guard:

-    scale = lora_config.lora_alpha / lora_config.r
+    if not getattr(lora_config, "r", None):
+        raise ValueError(f"Invalid LoRA config: r={getattr(lora_config, 'r', None)}")
+    scale = lora_config.lora_alpha / lora_config.r

156-160: Optional micro-optimization: short-circuit once both A and B found

Early-exiting the loop when both weights are found avoids iterating the entire lora_state on every tensor key (benefits large adapters).

Example:

     for lora_key, lora_weight in lora_state.items():
         if lora_key.endswith(f"{clean_key}.lora_A.weight"):
             lora_a = lora_weight
         elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
             lora_b = lora_weight
+        if lora_a is not None and lora_b is not None:
+            break
src/axolotl/cli/merge_lora.py (1)

24-36: Broaden fallback and clarify user guidance

Catching only RuntimeError may miss common failure modes (e.g., FileNotFoundError for missing shards/config). Consider broadening the exception set and keep the clear fallback message.

Example:

-    else:
-        try:
-            _do_merge_lora_efficient(cfg=cfg)
-        except RuntimeError as e:
+    else:
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except (RuntimeError, FileNotFoundError, OSError, ValueError) as e:
             LOG.error(f"Memory-efficient merge failed: {e}")
             LOG.info("Falling back to legacy merge method...")
             _do_merge_lora_legacy(cfg=cfg)
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 4ccf5ae and f094ea2.

📒 Files selected for processing (3)
  • src/axolotl/cli/merge_lora.py (3 hunks)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
  • src/axolotl/utils/schemas/peft.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/utils/schemas/peft.py
🧰 Additional context used
🧠 Learnings
📚 Learning (from winglian, PR axolotl-ai-cloud/axolotl#3095, src/axolotl/cli/merge_lora.py:65-81, 2025-08-22T13:23:41.434Z): The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/cli/merge_lora.py
  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (86-211)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (1)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

156-156: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)

151-178: Avoid VRAM growth: store CPU tensors in merged_tensors during safetensors path

Currently, tensors are read on device (often CUDA) and kept on GPU inside merged_tensors until the shard is saved, causing peak VRAM to scale with shard size. Convert to CPU before adding to merged_tensors to keep GPU usage bounded to a single tensor plus LoRA weights.
Apply this minimal-change diff:

-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
+                        merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype).detach().cpu()
                     else:
-                        merged_tensors[key] = tensor
+                        merged_tensors[key] = tensor.detach().cpu()

Note: Keeping results on CPU preserves memory-efficiency without changing compute placement. If desired, we can also open safetensors with device="cpu" and temporarily move tensors to device only for the addition; happy to provide that variant.


156-156: Ignore Ruff SIM118 hint here

The linter hint about using membership on dict instead of .keys() is not applicable; f is a safetensors.safe_open handle, not a dict. Iterating f.keys() is the intended API.

@SalmanMohammadi (Contributor):

@ved1beta could you also ensure the weights/logits produced by a model which was merged using the legacy vs. memory efficient method are identical?

@ved1beta (Contributor, Author):

This should be ensured by the test run?

Tested with examples/llama-3/qlora-1b.yml using a tiny Llama 1B Instruct model and `merge_method: memory_efficient`.
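
(For illustration, a minimal sketch of such a weights/logits comparison, assuming both merged models were saved locally in HF format; paths, prompt, and tolerances are placeholders:)

```python
# Hypothetical comparison of two merged outputs (legacy vs. memory-efficient).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def compare_merges(legacy_dir: str, efficient_dir: str, prompt: str = "Hello") -> None:
    m1 = AutoModelForCausalLM.from_pretrained(legacy_dir, torch_dtype=torch.float32)
    m2 = AutoModelForCausalLM.from_pretrained(efficient_dir, torch_dtype=torch.float32)

    # Weight-level check: every merged tensor should match.
    sd1, sd2 = m1.state_dict(), m2.state_dict()
    assert sd1.keys() == sd2.keys()
    for key in sd1:
        assert torch.allclose(sd1[key], sd2[key], atol=1e-6), f"Mismatch in {key}"

    # Logit-level check on a sample prompt.
    tok = AutoTokenizer.from_pretrained(legacy_dir)
    inputs = tok(prompt, return_tensors="pt")
    with torch.no_grad():
        logits1 = m1(**inputs).logits
        logits2 = m2(**inputs).logits
    assert torch.allclose(logits1, logits2, atol=1e-5)
```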

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

♻️ Duplicate comments (4)
src/axolotl/utils/lora_merge_efficient.py (4)

115-115: Nice: Helpful logging added

Scale factor, adapter path, and shard count logs improve debuggability and UX.

Also applies to: 125-126, 141-143


27-37: Bug: rstrip(".weight") is incorrect and corrupts keys

rstrip(".weight") removes any trailing characters in the set {'.', 'w', 'e', 'i', 'g', 'h', 't'}, not the exact suffix. This will cause mismatched/missed LoRA key lookups.

Apply this minimal fix:

-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key

43-55: Runtime error: list[Path]() is not a constructor

This raises at runtime. Initialize a list instead.

-    shards = list[Path]()
+    shards: list[Path] = []

197-206: Critical: do not rename .bin shards to .safetensors; ensure CPU tensors before saving

Renaming .bin to .safetensors while calling torch.save produces invalid safetensors and breaks HF index files. Preserve the original shard extension and only use safetensors.torch.save_file for .safetensors inputs. Always save CPU tensors.

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            # Preserve .bin format to keep HF indices valid
+            torch.save(merged_tensors, output_shard_path)
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)

21-41: Avoid O(N×M) scans of lora_state for every tensor

Scanning the entire lora_state per tensor is quadratic and slow on large models. Pre-index LoRA A/B weights once, then do O(1) lookups.

Example approach (new helper and usage):

# New helper (place near find_lora_weights)
def build_lora_index(lora_state: Dict[str, torch.Tensor]) -> dict[str, tuple[torch.Tensor|None, torch.Tensor|None]]:
    index: dict[str, tuple[Optional[torch.Tensor], Optional[torch.Tensor]]] = {}
    for k, v in lora_state.items():
        if k.endswith(".lora_A.weight"):
            base = k[:-len(".lora_A.weight")]
            a, b = index.get(base, (None, None))
            index[base] = (v, b)
        elif k.endswith(".lora_B.weight"):
            base = k[:-len(".lora_B.weight")]
            a, b = index.get(base, (None, None))
            index[base] = (a, v)
    return index

Then replace find_lora_weights(lora_state, key) with lookups like:

base = key[:-7] if key.endswith(".weight") else key
lora_a, lora_b = lora_index.get(base, (None, None))

90-92: Default device "cuda" is risky for a memory-efficient path

Defaulting to GPU can surprise users and increase VRAM usage. Consider defaulting to "cpu" and allowing callers to opt-in to a GPU device.

-    device: str = "cuda",
+    device: str = "cpu",
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between f094ea2 and fe157bd.

📒 Files selected for processing (1)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings
📚 Learning (from winglian, PR axolotl-ai-cloud/axolotl#3095, src/axolotl/cli/merge_lora.py:65-81, 2025-08-22T13:23:41.434Z): The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

156-156: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)

47-53: Patterns for HF shards look correct

Using "model*.safetensors" and "pytorch_model*.bin" aligns with HF conventions and will discover both single-file and sharded checkpoints.


68-84: Validation complete – all index JSON references are valid

  • Ran the provided validation script against both model.safetensors.index.json and pytorch_model.bin.index.json in merged_out; no missing shard files were reported.
  • The copy logic in lora_merge_efficient.py (lines 68–84) correctly skips model shards and .gguf files while preserving all other artifacts (e.g., tokenizer and config files).

With index consistency confirmed, no further changes are needed here.

Comment on lines 112 to 115
lora_config = LoraConfig.from_json_file(config_file)
scale = lora_config["lora_alpha"] / lora_config["r"]

LOG.info(f"LoRA scale factor: {scale}")
Contributor:


⚠️ Potential issue

Bug: LoraConfig fields accessed like a dict

LoraConfig.from_json_file returns an object; lora_config["lora_alpha"] will fail. Use attributes. Also guard against zero r.

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)
📝 Committable suggestion


Suggested change
-lora_config = LoraConfig.from_json_file(config_file)
-scale = lora_config["lora_alpha"] / lora_config["r"]
-LOG.info(f"LoRA scale factor: {scale}")
+lora_config = LoraConfig.from_json_file(config_file)
+# Ensure 'r' is present and non-zero to avoid division by zero
+if not getattr(lora_config, "r", None):
+    raise ValueError("LoRA config 'r' must be > 0")
+scale = float(lora_config.lora_alpha) / float(lora_config.r)
+LOG.info(f"LoRA scale factor: {scale}")
🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 112 to 115, the code
treats LoraConfig as a dict (lora_config["lora_alpha"]) which will raise since
LoraConfig.from_json_file returns an object; change to use attribute access
(e.g., lora_config.lora_alpha and lora_config.r) and add a guard to prevent
division by zero (check if lora_config.r is truthy/greater than zero before
dividing, otherwise set scale to a safe default or raise a clear error), then
log the computed scale.

Comment on lines 170 to 175
delta = scale * (
    lora_b.to(torch.float32) @ lora_a.to(torch.float32)
)

merged_tensor = (tensor_fp32 + delta).to(original_dtype)
merged_tensors[key] = merged_tensor
Contributor:


💡 Verification agent

🧩 Analysis chain

Support fan_in_fan_out orientation from config

If lora_config.fan_in_fan_out (or lora_fan_in_fan_out) is enabled, a transpose is required to match PEFT’s application. Current code ignores this and may produce incorrect merges for such adapters.

Minimal patch sketch inside both safetensors and .bin branches:

-                        delta = scale * (b_fp32 @ a_fp32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                            delta = delta.T

Please confirm the exact field name used by your PEFT version and add tests for a fan-in-fan-out adapter.

Also applies to: 188-194


🏁 Script executed:

#!/bin/bash
set -e

# Print context around the LoRA merge logic (lines ~150–220)
sed -n '150,220p' src/axolotl/utils/lora_merge_efficient.py

# Locate any definitions or imports of LoraConfig
rg -n "class.*LoraConfig" -g src
rg -n "import.*LoraConfig" -g src

# Check for existing handling of fan_in_fan_out flags
rg -n "fan_in_fan_out" -g src
rg -n "lora_fan_in_fan_out" -g src || true

Length of output: 2842


🏁 Script executed:

#!/bin/bash
set -e

# Print file header and initial context to locate function signature and imports
sed -n '1,200p' src/axolotl/utils/lora_merge_efficient.py

# Search for any mention of fan_in_fan_out in the entire repository
rg -n "fan_in_fan_out" .

# Search for LoraConfig class or imports across the codebase
rg -n "LoraConfig" .

# Search for usage of lora_config within this file
rg -n "lora_config" -n src/axolotl/utils/lora_merge_efficient.py

Length of output: 8450


Support LoRA fan_in_fan_out orientation in merge

The current merge logic in src/axolotl/utils/lora_merge_efficient.py always computes

delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))

without accounting for the fan_in_fan_out flag in the PEFT config, which will result in incorrect merges when adapters were trained with fan_in_fan_out=True.

Please apply the following change in both the .safetensors branch (around lines 170–175) and the .bin branch (around lines 188–194):

-   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   delta = scale * (lora_b.to(torch.float32) @ lora_a.to(torch.float32))
+   if lora_config.fan_in_fan_out:
+       delta = delta.T

• Locations to update:

  • safetensors loop (after line 170)
  • torch.load loop (after line 188)

• Add a unit test with a LoRA adapter configured as fan_in_fan_out=True to verify the transpose is applied correctly.

🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 170–175 (safetensors
branch) and around lines 188–194 (torch.load/.bin branch), the merge always
computes delta as scale * (lora_b @ lora_a) and ignores the PEFT config flag
fan_in_fan_out; update both locations to check the adapter config and, when
fan_in_fan_out is True, transpose lora_a and lora_b appropriately (e.g., swap or
transpose operands so multiplication reflects the trained orientation) before
computing delta, then cast back to original dtype as now; also add a unit test
that loads/creates a LoRA adapter with fan_in_fan_out=True, runs the merge, and
asserts the merged tensor matches the expected result when the transpose branch
is applied.

@ved1beta ved1beta requested a review from winglian August 23, 2025 05:15
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

♻️ Duplicate comments (8)
src/axolotl/utils/lora_merge_efficient.py (7)

27-40: Fix incorrect suffix removal and enforce exact LoRA key matching

rstrip(".weight") removes any trailing characters in the set {".weight"} and can corrupt keys; use an exact suffix check. Also keep precise endswith patterns for A/B weights.

-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key

-    for lora_key, lora_weight in lora_state.items():
-        if lora_key.endswith(f"{clean_key}.lora_A.weight"):
-            lora_a = lora_weight
-        elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
-            lora_b = lora_weight
+    for lora_key, lora_weight in lora_state.items():
+        if lora_key.endswith(f"{clean_key}.lora_A.weight"):
+            lora_a = lora_weight
+        elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
+            lora_b = lora_weight

Also applies to: 32-37


43-55: Initialize shards list correctly; fix runtime error

list[Path]() is a type subscription, not a constructor; it will throw at runtime.

-    shards = list[Path]()
+    shards: list[Path] = []

112-116: Access LoraConfig fields via attributes and guard division by zero

LoraConfig.from_json_file returns an object; indexing like a dict will fail. Also validate r > 0.

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)

132-136: Avoid VRAM spikes: don’t bulk-move all LoRA tensors to GPU

Moving the entire LoRA state to GPU defeats the “memory-efficient” goal and can OOM on small GPUs. Keep LoRA tensors on CPU and move per-tensor during merge.

-    if device != "cpu":
-        LOG.debug(f"Moving LoRA weights to {device}")
-        for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
-            lora_state[key] = value.to(device)
+    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

151-176: Load safetensors on CPU; per-tensor compute on device; support fan_in_fan_out; store results on CPU

  • safe_open(..., device=device) may load tensors directly to GPU; use CPU and JIT-move for compute.
  • Honor fan_in_fan_out orientation when present in the config.
  • Ensure merged tensors are on CPU before serialization.
-        if shard_path.suffix == ".safetensors":
-            with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
+        if shard_path.suffix == ".safetensors":
+            # Always open on CPU to minimize VRAM; move per-tensor as needed
+            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
                 if hasattr(f, "metadata") and f.metadata():
                     metadata = f.metadata()
@@
-                for key in f.keys():
+                for key in f.keys():
                     total_tensors += 1
-                    tensor = f.get_tensor(key)
+                    tensor = f.get_tensor(key)  # CPU tensor
                     lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                    if lora_a is not None and lora_b is not None:
+                    if lora_a is not None and lora_b is not None:
                         merged_count += 1
@@
-                        original_dtype = tensor.dtype
-                        tensor_fp32 = tensor.to(torch.float32)
-
-                        delta = scale * (
-                            lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                        )
-
-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
+                        original_dtype = tensor.dtype
+                        base_fp32 = tensor.to(device).to(torch.float32)
+                        a_fp32 = lora_a.to(device).to(torch.float32)
+                        b_fp32 = lora_b.to(device).to(torch.float32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                            delta = delta.T
+                        merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                     else:
-                        merged_tensors[key] = tensor
+                        merged_tensors[key] = tensor.detach().cpu()

179-196: Load .bin shards on CPU; compute on device; store results on CPU

Load state dict on CPU with weights_only=True and JIT-move to device for compute to avoid unnecessary VRAM usage.

-        else:
-            state_dict = torch.load(
-                shard_path, map_location=device
-            )  # nosec B614: loading trusted model weights
+        else:
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )
             for key, tensor in state_dict.items():
                 total_tensors += 1
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                if lora_a is not None and lora_b is not None:
+                if lora_a is not None and lora_b is not None:
                     merged_count += 1
                     original_dtype = tensor.dtype
-                    tensor_fp32 = tensor.to(torch.float32)
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                    base_fp32 = tensor.to(device).to(torch.float32)
+                    a_fp32 = lora_a.to(device).to(torch.float32)
+                    b_fp32 = lora_b.to(device).to(torch.float32)
+                    delta = scale * (b_fp32 @ a_fp32)
+                    if getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False):
+                        delta = delta.T
+                    merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()

197-206: Do not rename .bin shards to .safetensors; ensure CPU tensors before writing

Renaming .bin to .safetensors while calling torch.save produces invalid safetensors and breaks HF index files. Preserve the original shard format; always write CPU tensors; attach metadata only for safetensors.

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors for serialization
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches."
+                )
+            torch.save(merged_tensors, output_shard_path)
src/axolotl/cli/merge_lora.py (1)

80-91: Pass device explicitly to support CPU-only hosts

The efficient helper defaults to "cuda". On CPU-only machines this will raise. Detect CUDA availability and pass a device argument.

-    # Perform memory-efficient merge
-    merge_lora_sharded_efficient(
+    # Choose device: prefer CUDA if available, otherwise CPU
+    try:
+        import torch
+        has_cuda = torch.cuda.is_available()
+    except Exception:
+        has_cuda = False
+    device = "cuda" if has_cuda else "cpu"
+
+    # Perform memory-efficient merge
+    merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
         safe_tensors=safe_tensors,
+        device=device,
     )
🧹 Nitpick comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)

68-84: Optional: copy full directory tree (minus model shards) to preserve ancillary assets

Current implementation only copies top-level files. Some models include assets in subfolders (e.g., processors). Consider a recursive copy that excludes shard files by name.

Happy to propose a filtered copy routine using shutil.copytree with an ignore function that excludes shard files and GGUF.
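A minimal sketch of such a filtered copy, assuming shutil.copytree's ignore callback; the helper names (_ignore_weight_files, copy_non_model_tree) are illustrative and not existing Axolotl code:

import shutil
from pathlib import Path

def _ignore_weight_files(directory: str, names: list[str]) -> set[str]:
    # shutil.copytree ignore callback: skip weight shards and GGUF blobs
    skipped: set[str] = set()
    for name in names:
        path = Path(directory) / name
        if path.suffix == ".gguf":
            skipped.add(name)
        elif path.suffix in {".safetensors", ".bin"} and (
            name.startswith("model") or name.startswith("pytorch_model")
        ):
            skipped.add(name)
    return skipped

def copy_non_model_tree(base_model_path: Path, output_path: Path) -> None:
    # recursive copy of tokenizer/config/processor assets, excluding model weights
    shutil.copytree(
        base_model_path, output_path, ignore=_ignore_weight_files, dirs_exist_ok=True
    )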

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between fe157bd and d63de30.

📒 Files selected for processing (2)
  • src/axolotl/cli/merge_lora.py (3 hunks)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.434Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.434Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.434Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/cli/merge_lora.py
  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.387Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.387Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (86-211)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (1)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

156-156: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: preview
  • GitHub Check: pre-commit
🔇 Additional comments (4)
src/axolotl/utils/lora_merge_efficient.py (2)

141-146: Progress counters are helpful; consider logging merge ratio at end (already present)

The final summary log looks good and aligns with expectations for shard-wise merging.


211-211: Nice finishing log

Clear, actionable summary for users. LGTM.

src/axolotl/cli/merge_lora.py (2)

38-71: Legacy path looks consistent

Loading, merging, dtype cast guard, and saves are in line with the existing behavior. Minor logs are helpful. LGTM.


123-128: Helpful validation for lora_model_dir

Clear error message; good UX.

@winglian
Collaborator

Were you able to train a LoRA and then merge it using both the legacy and memory-efficient methods to verify that both produce identical merged weights?

@ved1beta
Contributor Author

ved1beta commented Aug 29, 2025

Yes, I tried merging and verified everything as you mentioned earlier; here is the training output: [slack link](https://ai-axolotl.slack.com/files/U09BE3G7ZED/F09BNKLDDNZ/untitled?origin_team=T05A3MTMVB8&origin_channel=D09BE3HMM7B)

I have a Claude-generated script for checking that the merged model weights are identical; it passes for the checkpoint produced by that training run.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (6)
src/axolotl/utils/lora_merge_efficient.py (5)

27-36: Fix suffix removal: rstrip is incorrect; use exact suffix logic

rstrip(".weight") removes any of those chars from the right end and can corrupt keys, breaking LoRA key matching.

Apply:

-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key

43-49: Initialize list correctly; current code raises at runtime

list[Path]() is not a constructor call.

Apply:

-    shards = list[Path]()
+    shards: list[Path] = []

112-116: Access LoraConfig attributes (not dict); guard divide-by-zero

LoraConfig.from_json_file returns an object; dict-style access will fail. Also handle r == 0/None.

Apply:

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)

198-206: Do not write PyTorch pickles with .safetensors extension; preserve original format

Current logic renames .bin shards to .safetensors but still uses torch.save, corrupting outputs and breaking index files.

Apply:

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {
+            k: (v.detach().cpu() if isinstance(v, torch.Tensor) else v)
+            for k, v in merged_tensors.items()
+        }
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches."
+                )
+            torch.save(merged_tensors, output_shard_path)

176-179: Safer deserialization: use weights_only=True and keep CPU

Avoids unnecessary pickle deserialization semantics.

Apply:

-            state_dict = torch.load(
-                shard_path, map_location="cpu"
-            )  # nosec B614: loading trusted model weights
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )
src/axolotl/cli/merge_lora.py (1)

24-36: Normalize merge_method; accept standard; broaden fallback and log exception

Improves UX (PR text mentions “standard”) and ensures fallback triggers on any failure. Also addresses TRY400.

Apply:

-    merge_method = getattr(cfg, "merge_method", "memory_efficient")
-    LOG.info(f"Using {merge_method} LoRA merge method")
-
-    if merge_method == "legacy":
-        _do_merge_lora_legacy(cfg=cfg)
-    else:
-        try:
-            _do_merge_lora_efficient(cfg=cfg)
-        except RuntimeError as e:
-            LOG.error(f"Memory-efficient merge failed: {e}")
-            LOG.info("Falling back to legacy merge method...")
-            _do_merge_lora_legacy(cfg=cfg)
+    merge_method = str(getattr(cfg, "merge_method", "memory_efficient")).lower().replace("-", "_")
+    if merge_method in {"legacy", "standard"}:
+        LOG.info("Using legacy LoRA merge method...")
+        _do_merge_lora_legacy(cfg=cfg)
+    else:
+        LOG.info("Using memory-efficient LoRA merge method...")
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except Exception:
+            LOG.exception("Memory-efficient merge failed")
+            LOG.info("Falling back to legacy merge method...")
+            _do_merge_lora_legacy(cfg=cfg)
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (1)

86-92: device parameter is unused; either remove it or use it per-tensor

Currently flagged by Ruff (ARG001). Since merges run on CPU for memory efficiency, simplest is to drop it.

Apply (and update call sites):

-def merge_lora_sharded_efficient(
+def merge_lora_sharded_efficient(
     base_model_path: Union[str, Path],
     lora_adapter_path: Union[str, Path],
     output_path: Union[str, Path],
-    device: str = "cpu",
     safe_tensors: bool = True,
 ) -> None:
src/axolotl/cli/merge_lora.py (1)

86-92: If dropping device in the helper, remove it here too

Keeps signatures consistent.

-    merge_lora_sharded_efficient(
+    merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
-        device="cpu",
         safe_tensors=safe_tensors,
     )
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d63de30 and 479e914.

📒 Files selected for processing (2)
  • src/axolotl/cli/merge_lora.py (3 hunks)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/cli/merge_lora.py
  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (86-213)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py

33-33: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

src/axolotl/utils/lora_merge_efficient.py

90-90: Unused function argument: device

(ARG001)


110-110: Avoid specifying long messages outside the exception class

(TRY003)


121-123: Avoid specifying long messages outside the exception class

(TRY003)


136-136: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py

[error] ruff-format reformatted 1 file during pre-commit; the hook failed in CI (exit code 1). Run 'pre-commit run --all-files' locally to apply formatting changes.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (1)

85-96: Install dependencies and verify merge outputs match
Install torch and safetensors, then set LEGACY_DIR/EFFICIENT_DIR and run the provided equivalence‐check script to confirm all tensor keys and values match exactly.

src/axolotl/cli/merge_lora.py (1)

17-23: Add a --validate-merge flag to compare both merge paths
Implement an optional --validate-merge (or --self-check) switch in do_merge_lora/do_cli that, when enabled, runs both the legacy and memory_efficient merges (into temp dirs or in-memory), loads all output tensors via safetensors.torch and torch.load, verifies the key sets match, and uses torch.allclose(atol=0, rtol=0) on each tensor—erroring or warning on any mismatch. This ensures users can confirm both implementations produce identical merged weights.
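A sketch of what the comparison step behind such a flag could look like, assuming both merges have already written their shards to separate directories; the helper names below are illustrative, not the PR's API:

from pathlib import Path

import safetensors.torch
import torch

def _load_all_tensors(directory: Path) -> dict[str, torch.Tensor]:
    # gather every tensor from all shards in a merged-output directory
    tensors: dict[str, torch.Tensor] = {}
    for shard in sorted(directory.glob("*.safetensors")):
        tensors.update(safetensors.torch.load_file(shard))
    for shard in sorted(directory.glob("*.bin")):
        tensors.update(torch.load(shard, map_location="cpu", weights_only=True))
    return tensors

def assert_merges_match(legacy_dir: Path, efficient_dir: Path) -> None:
    legacy = _load_all_tensors(legacy_dir)
    efficient = _load_all_tensors(efficient_dir)
    assert legacy.keys() == efficient.keys(), "tensor key sets differ between merges"
    for key, value in legacy.items():
        assert torch.allclose(
            value.float(), efficient[key].float(), rtol=0, atol=0
        ), f"tensor mismatch for {key}"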

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (6)
src/axolotl/utils/lora_merge_efficient.py (6)

27-36: Bug: rstrip(".weight") corrupts keys; remove exact suffix instead

rstrip removes any trailing chars from the set, not the exact substring. This can mis-match LoRA keys.

Apply:

-    clean_key = key.rstrip(".weight")
+    # Remove only the exact ".weight" suffix
+    clean_key = key[:-7] if key.endswith(".weight") else key

43-55: Bug: list[Path]() is not a constructor call

This raises at runtime. Initialize normally.

-    shards = list[Path]()
+    shards: list[Path] = []

197-205: Critical: don’t rename .bin→.safetensors or write pickles with .safetensors extension

This corrupts outputs and breaks HF index references. Preserve original shard format; ensure CPU tensors before writing.

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format "
+                    "to avoid index mismatches. Consider a separate convert step."
+                )
+            torch.save(merged_tensors, output_shard_path)

182-193: Correctness: handle fan_in_fan_out in .bin branch too

Mirror the transpose logic here.

-                    delta = scale * (lora_b_fp32 @ lora_a_fp32)
+                    delta = scale * (lora_b_fp32 @ lora_a_fp32)
+                    if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                        delta = delta.T

111-114: Bug: LoraConfig used like dict; add zero-division guard

from_json_file returns an object. Also guard r>0.

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None) or lora_config.r <= 0:
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)

147-174: Correctness: handle fan_in_fan_out orientation when merging safetensors shard

Adapters trained with fan_in_fan_out=True require a transpose.

-                        delta = scale * (lora_b_fp32 @ lora_a_fp32)
+                        delta = scale * (lora_b_fp32 @ lora_a_fp32)
+                        if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                            delta = delta.T
🧹 Nitpick comments (4)
src/axolotl/utils/lora_merge_efficient.py (4)

21-41: Speed up lookups by pre-indexing LoRA A/B once

Current O(N_base * N_lora) scan per tensor is avoidable. Build a suffix→(A,B) index once.

Option sketch (new helper outside diff for context):

def build_lora_index(lora_state: dict[str, torch.Tensor]) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    a_map, b_map = {}, {}
    for k, v in lora_state.items():
        if k.endswith(".lora_A.weight"):
            a_map[k[: -len(".lora_A.weight")]] = v
        elif k.endswith(".lora_B.weight"):
            b_map[k[: -len(".lora_B.weight")]] = v
    return {k: (a_map[k], b_map[k]) for k in a_map.keys() & b_map.keys()}

Then in the merge loop, resolve with lora_index.get(clean_key).


82-84: Preserve file metadata when copying

Use copy2 to retain mtime/permissions.

-        shutil.copy(filepath, output_path)
+        shutil.copy2(filepath, output_path)

175-181: Safety/efficiency: load .bin shards weights-only on CPU

Prefer weights_only=True to avoid executing pickles; we only need tensors.

-            state_dict = torch.load(
-                shard_path, map_location="cpu"
-            )  # nosec B614: loading trusted model weights
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )

143-174: Optional: keep tensors on CPU but detach/cpu all merged values explicitly

Be explicit to avoid accidental device retention if future changes introduce GPU ops.

-                        merged_tensors[key] = merged_tensor
+                        merged_tensors[key] = merged_tensor.detach().cpu()
@@
-                    merged_tensors[key] = merged_tensor
+                    merged_tensors[key] = merged_tensor.detach().cpu()

Also applies to: 175-196

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 479e914 and 7145972.

📒 Files selected for processing (1)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

109-109: Avoid specifying long messages outside the exception class

(TRY003)


120-122: Avoid specifying long messages outside the exception class

(TRY003)


135-135: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py

[error] 172-172: ruff-format: File reformatted by pre-commit; please re-run pre-commit or commit formatting changes.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/utils/lora_merge_efficient.py (3)

47-54: Shard patterns look good

Using model*.safetensors and pytorch_model*.bin matches HF conventions (per our prior learning).


140-174: Manual parity and VRAM profiling required
The automated script couldn’t run in this environment (missing torch), so please verify in your setup:

  • Compare all tensors in legacy_merged/ vs. memory_eff_merged/ byte-for-byte (using strict allclose for floats and equal for ints).
  • Profile peak GPU memory during both merge routines (e.g. nvidia-smi --query-gpu=memory.used --loop-ms=500) and report savings.
    Also apply the same checks to the code blocks at lines 175–196 and 197–205.

1-213: Install and run pre-commit hooks to apply ruff-format fixes
CI is failing on ruff-format; ensure you have pre-commit installed and run:

pre-commit install
pre-commit run --all-files

to reflow and commit the formatting changes.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (8)
src/axolotl/utils/lora_merge_efficient.py (8)

27-27: Do not use rstrip(".weight") — it corrupts keys

rstrip(".weight") removes any trailing chars from the set {'.', 'w', 'e', 'i', 'g', 'h', 't'}, not the exact suffix. This can lead to false matches. Use explicit suffix removal.

-    clean_key = key.rstrip(".weight")
+    clean_key = key[:-7] if key.endswith(".weight") else key

43-55: Runtime error: list[Path]() is not a constructor

This will raise at runtime. Initialize as a standard list.

-    shards = list[Path]()
+    shards: list[Path] = []

112-115: LoraConfig is an object, not a dict; guard against zero r

LoraConfig.from_json_file returns an object; lora_config["..."] will fail. Also prevent division by zero.

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(str(config_file))
+    if not getattr(lora_config, "r", None) or float(lora_config.r) <= 0:
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(getattr(lora_config, "lora_alpha", 1.0)) / float(lora_config.r)

132-136: Keep LoRA state on CPU; avoid bulk device transfer

Moving the entire adapter to GPU defeats the memory-efficient design and can spike VRAM.

-    if device != "cpu":
-        LOG.info(f"Moving LoRA weights to {device}")
-        for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
-            lora_state[key] = value.to(device)
+    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

151-159: Open safetensors shards on CPU to control VRAM; move per-tensor for compute

Loading tensors directly on GPU risks VRAM blowups and makes serialization harder.

-        if shard_path.suffix == ".safetensors":
-            with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
+        if shard_path.suffix == ".safetensors":
+            # Always open on CPU; move specific tensors to `device` only for compute
+            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:

179-200: Load .bin shards on CPU; compute per-tensor on device; store CPU tensors; prefer weights_only=True

This keeps memory bounded, improves safety, and aligns with the safetensors branch.

-            state_dict = torch.load(
-                shard_path, map_location=device
-            )  # nosec B614: loading trusted model weights
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )
@@
-                    original_dtype = tensor.dtype
-                    tensor_fp32 = tensor.to(torch.float32)
-                    lora_a_fp32 = lora_a.to(torch.float32)
-                    lora_b_fp32 = lora_b.to(torch.float32)
-
-                    delta = scale * (lora_b_fp32 @ lora_a_fp32)
-                    merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                    merged_tensors[key] = merged_tensor
-
-                    del tensor_fp32, lora_a_fp32, lora_b_fp32, delta
+                    original_dtype = tensor.dtype
+                    base_fp32 = tensor.to(device).to(torch.float32)
+                    a_fp32 = lora_a.to(device).to(torch.float32)
+                    b_fp32 = lora_b.to(device).to(torch.float32)
+                    delta = scale * (b_fp32 @ a_fp32)
+                    if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                        delta = delta.T
+                    merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
+                    del base_fp32, a_fp32, b_fp32, delta
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()

201-209: Do not mislabel file formats; preserve original shard extension and write CPU tensors

Currently, when input is .bin and safe_tensors=True, you rename to .safetensors but still use torch.save, producing invalid safetensors and breaking HF indices. Also ensure all tensors are on CPU before saving.

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format "
+                    "to avoid index mismatches."
+                )
+            torch.save(merged_tensors, output_shard_path)

167-175: Compute on device, respect fan_in_fan_out, and store CPU tensors for save

  • Do compute on device in FP32, but store merged results on CPU.
  • Handle PEFT’s fan_in_fan_out=True by transposing delta.
  • Avoid keeping GPU tensors in merged_tensors (safetensors requires CPU tensors).
-                        original_dtype = tensor.dtype
-                        tensor_fp32 = tensor.to(torch.float32)
-                        lora_a_fp32 = lora_a.to(torch.float32)
-                        lora_b_fp32 = lora_b.to(torch.float32)
-
-                        delta = scale * (lora_b_fp32 @ lora_a_fp32)
-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
-                        del tensor_fp32, lora_a_fp32, lora_b_fp32, delta
+                        original_dtype = tensor.dtype
+                        base_fp32 = tensor.to(device).to(torch.float32)
+                        a_fp32 = lora_a.to(device).to(torch.float32)
+                        b_fp32 = lora_b.to(device).to(torch.float32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                            delta = delta.T
+                        merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
+                        del base_fp32, a_fp32, b_fp32, delta
🧹 Nitpick comments (5)
src/axolotl/utils/lora_merge_efficient.py (5)

32-38: Stop scanning once both LoRA tensors are found

Micro-optimization: break early to avoid O(N) scan over all adapter tensors per key.

-    for lora_key, lora_weight in lora_state.items():
+    for lora_key, lora_weight in lora_state.items():
         if lora_key.endswith(f"{clean_key}.lora_A.weight"):
             lora_a = lora_weight
         elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
             lora_b = lora_weight
+        if lora_a is not None and lora_b is not None:
+            break

211-214: Guard CUDA cache calls

Avoid calling CUDA APIs when CUDA isn’t available.

-        if device != "cpu":
-            torch.cuda.empty_cache()
+        if device != "cpu" and torch.cuda.is_available():
+            torch.cuda.empty_cache()

21-41: Performance: avoid O(N×M) adapter scans by pre-indexing LoRA keys

Current per-tensor scan over the entire lora_state is costly on large models. Build a suffix-index map once.

Happy to provide a follow-up patch that constructs:

  • a map from cleaned base key suffix → (A, B)
  • or a trie/suffix map keyed by last two path segments

Let me know if you want the diff.

Also applies to: 137-146


112-115: Correctness: ensure LoRA orientation (fan_in_fan_out) is supported

Adapters trained with fan_in_fan_out=True require transposition during merge; the proposed diffs add this. Please add a unit test exercising this flag.

Also applies to: 167-175, 188-194
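A minimal unit-test sketch for that flag, checking only the delta orientation math rather than the full shard pipeline; shapes and scale here are illustrative assumptions:

import torch

def test_fan_in_fan_out_transposes_delta():
    in_features, out_features, rank = 4, 8, 2
    # fan_in_fan_out layers (e.g. GPT-2 style Conv1D) store the base weight as (in, out)
    base = torch.zeros(in_features, out_features)
    lora_a = torch.randn(rank, in_features)
    lora_b = torch.randn(out_features, rank)
    scale = 1.0
    delta = scale * (lora_b @ lora_a)  # (out, in): the default Linear orientation
    merged = base + delta.T            # transpose to match the (in, out) layout
    assert merged.shape == base.shape
    assert torch.allclose(merged, delta.T)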


137-143: Log improvements: shard count and adapter path

Add low-noise debug logs to aid support; aligns with reviewer suggestions.

-    LOG.debug(f"Found {len(model_shards)} model shards")
+    LOG.debug(f"Found {len(model_shards)} model shards in {base_model_path}")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 7145972 and 0cc2792.

📒 Files selected for processing (1)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

110-110: Avoid specifying long messages outside the exception class

(TRY003)


121-123: Avoid specifying long messages outside the exception class

(TRY003)


139-139: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py

[error] 176-179: ruff-format: 1 file reformatted by the hook; please re-run 'pre-commit run --all-files' and commit the changes.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)

139-139: Run pre-commit locally and centralize exception messages

Pre-commit isn’t available here; please run

pre-commit run --all-files

locally to apply any required formatting fixes. Also shorten or centralize long exception messages (e.g. at lines 139 and 215).
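One way to centralize such a message (the pattern ruff's TRY003 points at); the exception name below is an assumption, not existing Axolotl code:

class MissingAdapterConfigError(FileNotFoundError):
    """Raised when adapter_config.json is missing from the LoRA adapter directory."""

    def __init__(self, config_path):
        super().__init__(f"LoRA adapter config not found: {config_path}")

The raise site then reduces to raise MissingAdapterConfigError(config_file).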


215-215: Add automated parity and VRAM measurement tests
CI failed to run the parity script due to missing torch—ensure your test environment installs PyTorch and add a CI step that:

  • Runs merge_lora_sharded_efficient and merge_lora on the same base+adapter
  • Asserts all tensors match (rtol=1e-4, atol=1e-5)
  • Captures peak GPU memory during each merge to validate the expected savings
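One possible shape for the peak-memory half of such a test; how the two merge paths are invoked is left to the test harness, and the helper below is an assumption rather than existing test code:

import torch

def peak_cuda_memory_bytes(fn, *args, **kwargs) -> int:
    # run fn and report peak CUDA memory allocated; returns 0 on CPU-only hosts
    if not torch.cuda.is_available():
        fn(*args, **kwargs)
        return 0
    torch.cuda.reset_peak_memory_stats()
    fn(*args, **kwargs)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()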

Comment on lines 86 to 96
def merge_lora_sharded_efficient(
    base_model_path: Union[str, Path],
    lora_adapter_path: Union[str, Path],
    output_path: Union[str, Path],
    device: str = "cuda",
    safe_tensors: bool = True,
) -> None:
    """
    Memory-efficient LoRA merging that processes shards individually
    without loading the full model into memory.
    """
Contributor

🛠️ Refactor suggestion

Safety: prevent in-place overwrite of source directory

If output_path equals base_model_path, you risk clobbering source files. Add a guard and fail fast.

     output_path = Path(output_path)
@@
-    os.makedirs(output_path, exist_ok=True)
+    if output_path.resolve() == base_model_path.resolve():
+        raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+    os.makedirs(output_path, exist_ok=True)

Also applies to: 101-106

🤖 Prompt for AI Agents
In src/axolotl/utils/lora_merge_efficient.py around lines 86-96 (and similarly
at 101-106), add a fail-fast guard that prevents output_path from being the same
as base_model_path (and also disallow matching lora_adapter_path) to avoid
in-place overwrites; implement by converting inputs to pathlib.Path and
comparing resolved absolute paths (Path(...).resolve()) and if any matches raise
a clear ValueError (or SystemExit) with a message like "output_path must be
different from base_model_path/lora_adapter_path" before performing any file
operations.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (11)
src/axolotl/utils/lora_merge_efficient.py (9)

11-14: Fix import: NameError on safetensors.safe_open

You call safetensors.safe_open but never import the top-level module.

-import safetensors.torch
+import safetensors  # needed for safetensors.safe_open
+import safetensors.torch

45-49: Initialize list correctly

list[Path] is a type hint, not a constructor call; this will raise.

-    shards = list[Path]()
+    shards: list[Path] = []

72-81: Skip all HF weight files not selected as shards

Avoid copying opposite-format weights into output (mixed repo).

-        if filepath.name in shard_names:
-            continue
-        if filepath.suffix == ".gguf":
-            continue
-        if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
-            continue
+        if filepath.name in shard_names:
+            continue
+        # Skip any other HF weight files
+        if (
+            (filepath.suffix in {".safetensors", ".bin"})
+            and (filepath.name.startswith("model") or filepath.name.startswith("pytorch_model"))
+        ):
+            continue
+        if filepath.suffix == ".gguf":
+            continue

101-106: Guard against in-place overwrite of source dirs

Prevent output_path == base_model_path or lora_adapter_path.

-    os.makedirs(output_path, exist_ok=True)
+    if output_path.resolve() in {base_model_path.resolve(), lora_adapter_path.resolve()}:
+        raise ValueError("output_path must differ from base_model_path and lora_adapter_path")
+    os.makedirs(output_path, exist_ok=True)

112-115: Access LoraConfig attributes, validate r/alpha

LoraConfig is an object, not a dict; also guard r>0 and presence of lora_alpha.

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None) or not getattr(lora_config, "lora_alpha", None):
+        raise ValueError("LoRA config missing required fields: 'r' and 'lora_alpha'")
+    if int(lora_config.r) == 0:
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)

132-136: Keep LoRA on CPU; avoid bulk GPU transfer

Bulk moving defeats “memory-efficient” goal and can OOM.

-    if device != "cpu":
-        LOG.info(f"Moving LoRA weights to {device}")
-        for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
-            lora_state[key] = value.to(device)
+    LOG.debug("Keeping LoRA weights on CPU; moving per-tensor during merge")

151-176: Open safetensors on CPU; per-tensor device compute; handle fan_in_fan_out; store CPU tensors

Prevents VRAM spikes and ensures correct orientation; outputs must be CPU tensors for serialization.

-        if shard_path.suffix == ".safetensors":
-            with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
+        if shard_path.suffix == ".safetensors":
+            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
                 if hasattr(f, "metadata") and f.metadata():
                     metadata = f.metadata()
@@
-                    tensor = f.get_tensor(key)
+                    tensor = f.get_tensor(key)  # CPU tensor
                     lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                        original_dtype = tensor.dtype
-                        tensor_fp32 = tensor.to(torch.float32)
-
-                        delta = scale * (
-                            lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                        )
-
-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
+                        original_dtype = tensor.dtype
+                        base_fp32 = tensor.to(device).to(torch.float32)
+                        a_fp32 = lora_a.to(device).to(torch.float32)
+                        b_fp32 = lora_b.to(device).to(torch.float32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                            delta = delta.T
+                        merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                     else:
-                        merged_tensors[key] = tensor
+                        merged_tensors[key] = tensor.detach().cpu()

179-196: Load .bin shards on CPU; per-tensor device compute; store CPU tensors; handle fan_in_fan_out

Avoids GPU load of entire shard; adds robustness with weights_only.

-        else:
-            state_dict = torch.load(
-                shard_path, map_location=device
-            )  # nosec B614: loading trusted model weights
+        else:
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )
             for key, tensor in state_dict.items():
                 total_tensors += 1
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                if lora_a is not None and lora_b is not None:
+                if lora_a is not None and lora_b is not None:
                     merged_count += 1
                     original_dtype = tensor.dtype
-                    tensor_fp32 = tensor.to(torch.float32)
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                    base_fp32 = tensor.to(device).to(torch.float32)
+                    a_fp32 = lora_a.to(device).to(torch.float32)
+                    b_fp32 = lora_b.to(device).to(torch.float32)
+                    delta = scale * (b_fp32 @ a_fp32)
+                    if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                        delta = delta.T
+                    merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()

197-206: Do not rename .bin shards to .safetensors; preserve original format

Renaming while using torch.save corrupts repos and breaks indices. Always save CPU tensors.

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format to avoid index mismatches."
+                )
+            torch.save(merged_tensors, output_shard_path)
src/axolotl/cli/merge_lora.py (2)

24-36: Normalize merge_method; accept 'standard'; broaden fallback; log exception

Matches PR wording and avoids silent failures.

-    merge_method = getattr(cfg, "merge_method", "memory_efficient")
-    LOG.info(f"Using {merge_method} LoRA merge method")
-
-    if merge_method == "legacy":
-        _do_merge_lora_legacy(cfg=cfg)
-    else:
-        try:
-            _do_merge_lora_efficient(cfg=cfg)
-        except RuntimeError as e:
-            LOG.error(f"Memory-efficient merge failed: {e}")
-            LOG.info("Falling back to legacy merge method...")
-            _do_merge_lora_legacy(cfg=cfg)
+    merge_method = str(getattr(cfg, "merge_method", "memory_efficient")).lower().replace("-", "_")
+    if merge_method in {"legacy", "standard"}:
+        LOG.info("Using legacy LoRA merge method...")
+        _do_merge_lora_legacy(cfg=cfg)
+    else:
+        LOG.info("Using memory-efficient LoRA merge method...")
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except Exception:
+            LOG.exception("Memory-efficient merge failed; falling back to legacy merge...")
+            _do_merge_lora_legacy(cfg=cfg)

80-91: Pass device explicitly; default to CPU when CUDA unavailable

Prevents crashes on CPU-only hosts and avoids hidden GPU use.

-    output_path = Path(cfg.output_dir) / "merged"
-    safe_tensors = getattr(cfg, "save_safetensors", True)
+    output_path = Path(cfg.output_dir) / "merged"
+    safe_tensors = getattr(cfg, "save_safetensors", True)
+    # Choose device: prefer CUDA if available
+    try:
+        import torch  # local import
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    except Exception:
+        device = "cpu"
@@
-    merge_lora_sharded_efficient(
+    merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
-        safe_tensors=safe_tensors,
+        safe_tensors=safe_tensors,
+        device=device,
     )
🧹 Nitpick comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)

208-209: Guard CUDA cache calls

Avoid calling CUDA APIs on CPU-only hosts.

-        if device != "cpu":
-            torch.cuda.empty_cache()
+        if device != "cpu" and torch.cuda.is_available():
+            torch.cuda.empty_cache()
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 0cc2792 and 609f07b.

📒 Files selected for processing (2)
  • src/axolotl/cli/merge_lora.py (3 hunks)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/cli/merge_lora.py
  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (2)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (86-211)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py

33-33: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

src/axolotl/utils/lora_merge_efficient.py

110-110: Avoid specifying long messages outside the exception class

(TRY003)


121-123: Avoid specifying long messages outside the exception class

(TRY003)


139-139: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: lint
src/axolotl/utils/lora_merge_efficient.py

[error] 176-176: ruff-format reformatted this file during pre-commit; the pre-commit hook failed with exit code 1. Run 'pre-commit run --all-files' and commit the updated file to apply formatting.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
🔇 Additional comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)

1-212: Install and run pre-commit to fix formatting
pre-commit wasn’t found in this environment. Ensure it’s installed and run

pre-commit run --all-files

to apply lint/format fixes before committing.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
src/axolotl/cli/merge_lora.py (4)

80-80: Prefer debug level for this trace message (consistency with other suggestions).

-    LOG.info("Using memory-efficient LoRA merging method...")
+    LOG.debug("Using memory-efficient LoRA merging method...")

127-127: Fix wording: this path points to adapter weights, not a merged model directory.

-            f"Target directory for LoRA merged model does not exist: `{parsed_cfg.lora_model_dir}`"
+            f"Target directory for LoRA adapter weights does not exist: `{parsed_cfg.lora_model_dir}`"

24-35: Normalize merge_method, accept 'standard', and guarantee fallback with stacktrace.

Without normalization, merge_method: standard (as advertised in the PR) routes to the efficient path unintentionally. Also, catching only RuntimeError can miss common failures (e.g., CPU-only hosts raising AssertionError when device defaults to "cuda"), preventing the promised fallback. Use LOG.exception to keep the stacktrace (ruff TRY400).

Apply:

-    merge_method = getattr(cfg, "merge_method", "memory_efficient")
-    LOG.info(f"Using {merge_method} LoRA merge method")
-
-    if merge_method == "legacy":
-        _do_merge_lora_legacy(cfg=cfg)
-    else:
-        try:
-            _do_merge_lora_efficient(cfg=cfg)
-        except RuntimeError as e:
-            LOG.error(f"Memory-efficient merge failed: {e}")
-            LOG.info("Falling back to legacy merge method...")
-            _do_merge_lora_legacy(cfg=cfg)
+    merge_method = str(getattr(cfg, "merge_method", "memory_efficient") or "memory_efficient").strip().lower().replace("-", "_")
+    if merge_method in {"legacy", "standard"}:
+        LOG.debug("Using legacy LoRA merging method...")
+        _do_merge_lora_legacy(cfg=cfg)
+    else:
+        LOG.debug("Using memory-efficient LoRA merging method...")
+        try:
+            _do_merge_lora_efficient(cfg=cfg)
+        except Exception:
+            LOG.exception("Memory-efficient merge failed; falling back to legacy.")
+            _do_merge_lora_legacy(cfg=cfg)

82-91: Pass device explicitly to avoid CPU-only crashes and unnecessary fallbacks.

The efficient helper defaults to device="cuda". On CPU-only hosts this will throw before merging. Pick device based on CUDA availability and pass it through.

     output_path = Path(cfg.output_dir) / "merged"
     safe_tensors = getattr(cfg, "save_safetensors", True)

     # Perform memory-efficient merge
+    try:
+        import torch  # local import to keep CLI lightweight
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    except Exception:
+        device = "cpu"
     merge_lora_sharded_efficient(
         base_model_path=cfg.base_model,
         lora_adapter_path=cfg.lora_model_dir,
         output_path=output_path,
-        safe_tensors=safe_tensors,
+        safe_tensors=safe_tensors,
+        device=device,
     )
🧹 Nitpick comments (1)
src/axolotl/cli/merge_lora.py (1)

45-46: Unify safe-serialization default with efficient path.

Use getattr with a default instead of identity comparison for consistent behavior.

-    safe_serialization = cfg.save_safetensors is True
+    safe_serialization = getattr(cfg, "save_safetensors", True)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 609f07b and 4b2fc64.

📒 Files selected for processing (1)
  • src/axolotl/cli/merge_lora.py (3 hunks)
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/cli/merge_lora.py
🧬 Code graph analysis (1)
src/axolotl/cli/merge_lora.py (3)
src/axolotl/utils/lora_merge_efficient.py (1)
  • merge_lora_sharded_efficient (86-211)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 Ruff (0.12.2)
src/axolotl/cli/merge_lora.py

33-33: Use logging.exception instead of logging.error

Replace with exception

(TRY400)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: preview
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (4)
src/axolotl/cli/merge_lora.py (4)

12-12: Import of efficient merge helper looks good.


19-19: Docstring accurately reflects dual strategies.


38-44: Legacy merge entrypoint/readability is fine.


93-93: Success debug message reads well.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (9)
src/axolotl/utils/lora_merge_efficient.py (9)

27-40: Do not use rstrip(".weight"); exact suffix and O(1) lookups.

rstrip removes any trailing chars in the set ".weight", corrupting keys and causing mis-merges.

Apply:

-    clean_key = key.rstrip(".weight")
-
-    lora_a = None
-    lora_b = None
-
-    for lora_key, lora_weight in lora_state.items():
-        if lora_key.endswith(f"{clean_key}.lora_A.weight"):
-            lora_a = lora_weight
-        elif lora_key.endswith(f"{clean_key}.lora_B.weight"):
-            lora_b = lora_weight
-
-    if lora_a is not None and lora_b is not None:
-        return lora_a, lora_b
-    return None, None
+    clean_key = key[:-7] if key.endswith(".weight") else key
+    a_key = f"{clean_key}.lora_A.weight"
+    b_key = f"{clean_key}.lora_B.weight"
+    lora_a = lora_state.get(a_key)
+    lora_b = lora_state.get(b_key)
+    if lora_a is not None and lora_b is not None:
+        return lora_a, lora_b
+    return None, None

45-55: Annotate an empty list instead of calling `list[Path]()`

`list[Path]()` does work at runtime, but the annotated empty list is clearer and friendlier to type checkers.

Apply:

-    shards = list[Path]()
+    shards: list[Path] = []

97-106: Guard against overwriting source dirs.

Disallow output_path == base_model_path or == lora_adapter_path.

Apply:

     output_path = Path(output_path)
@@
-    os.makedirs(output_path, exist_ok=True)
+    if output_path.resolve() == base_model_path.resolve():
+        raise ValueError("output_path must differ from base_model_path to avoid overwriting source shards")
+    if output_path.resolve() == lora_adapter_path.resolve():
+        raise ValueError("output_path must differ from lora_adapter_path")
+    os.makedirs(output_path, exist_ok=True)

132-136: Keep LoRA on CPU; avoid bulk device transfer.

Bulk .to(device) defeats the memory-efficient goal and can OOM.

Apply:

-    if device != "cpu":
-        LOG.debug(f"Moving LoRA weights to {device}")
-        for key, value in tqdm(lora_state.items(), desc="Moving LoRA to device"):
-            lora_state[key] = value.to(device)
+    # Keep LoRA on CPU; move per-tensor during merge
+    LOG.debug("Keeping LoRA weights on CPU; will move per-tensor during merge")

195-204: Do not write PyTorch pickles with a .safetensors extension; preserve input shard format.

Renaming .bin to .safetensors while calling torch.save corrupts outputs and breaks HF indices.

Apply:

-        output_shard_path = output_path / shard_path.name
-        if safe_tensors and shard_path.suffix == ".safetensors":
-            safetensors.torch.save_file(
-                merged_tensors, output_shard_path, metadata=metadata
-            )
-        else:
-            if safe_tensors:
-                output_shard_path = output_shard_path.with_suffix(".safetensors")
-            torch.save(merged_tensors, output_shard_path)
+        output_shard_path = output_path / shard_path.name
+        # Ensure CPU tensors before writing
+        merged_tensors = {k: v.detach().cpu() for k, v in merged_tensors.items()}
+        if shard_path.suffix == ".safetensors":
+            safetensors.torch.save_file(merged_tensors, output_shard_path, metadata=metadata)
+        else:
+            if safe_tensors:
+                LOG.warning(
+                    "safe_tensors=True requested but input shards are .bin; preserving .bin format "
+                    "to avoid index mismatches. Conversion is not implemented here."
+                )
+            torch.save(merged_tensors, output_shard_path)

11-14: Make the top-level safetensors import explicit for safetensors.safe_open

Only safetensors.torch is imported; importing the submodule does bind the parent package, but an explicit top-level import makes the safe_open dependency clear.

Apply:

-import safetensors.torch
+import safetensors  # needed for safetensors.safe_open
+import safetensors.torch

112-116: LoraConfig is an object, not a dict; also guard r > 0.

Current code will raise; division by zero risk.

Apply:

-    lora_config = LoraConfig.from_json_file(config_file)
-    scale = lora_config["lora_alpha"] / lora_config["r"]
+    lora_config = LoraConfig.from_json_file(config_file)
+    if not getattr(lora_config, "r", None):
+        raise ValueError("LoRA config 'r' must be > 0")
+    scale = float(lora_config.lora_alpha) / float(lora_config.r)

151-177: Open safetensors on CPU; per-tensor JIT moves; support fan_in_fan_out; store results on CPU.

Current code opens on device, does FP32 on host device without ensuring device alignment, and ignores fan_in_fan_out.

Apply:

-        if shard_path.suffix == ".safetensors":
-            with safetensors.safe_open(shard_path, framework="pt", device=device) as f:
+        if shard_path.suffix == ".safetensors":
+            with safetensors.safe_open(shard_path, framework="pt", device="cpu") as f:
                 if hasattr(f, "metadata") and f.metadata():
                     metadata = f.metadata()
@@
-                for key in f.keys():
+                for key in f.keys():
                     total_tensors += 1
-                    tensor = f.get_tensor(key)
+                    tensor = f.get_tensor(key)  # CPU tensor
                     lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                    if lora_a is not None and lora_b is not None:
+                    if lora_a is not None and lora_b is not None:
                         merged_count += 1
@@
-                        original_dtype = tensor.dtype
-                        tensor_fp32 = tensor.to(torch.float32)
-
-                        delta = scale * (
-                            lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                        )
-
-                        merged_tensor = (tensor_fp32 + delta).to(original_dtype)
-                        merged_tensors[key] = merged_tensor
+                        original_dtype = tensor.dtype
+                        base_fp32 = tensor.to(device).to(torch.float32)
+                        a_fp32 = lora_a.to(device).to(torch.float32)
+                        b_fp32 = lora_b.to(device).to(torch.float32)
+                        delta = scale * (b_fp32 @ a_fp32)
+                        if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                            delta = delta.T
+                        merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                     else:
-                        merged_tensors[key] = tensor
+                        merged_tensors[key] = tensor.detach().cpu()

178-194: Load .bin shards on CPU; weights_only; per-tensor JIT moves; support fan_in_fan_out; store on CPU.

Avoid GPU preload; ensure robustness and memory efficiency.

Apply:

-        else:
-            state_dict = torch.load(shard_path, map_location=device)  # nosec B614: loading trusted model weights
+        else:
+            state_dict = torch.load(  # nosec B614: loading trusted model weights
+                shard_path, map_location="cpu", weights_only=True
+            )
             for key, tensor in state_dict.items():
                 total_tensors += 1
                 lora_a, lora_b = find_lora_weights(lora_state, key)
@@
-                if lora_a is not None and lora_b is not None:
+                if lora_a is not None and lora_b is not None:
                     merged_count += 1
                     original_dtype = tensor.dtype
-                    tensor_fp32 = tensor.to(torch.float32)
-                    delta = scale * (
-                        lora_b.to(torch.float32) @ lora_a.to(torch.float32)
-                    )
-                    merged_tensors[key] = (tensor_fp32 + delta).to(original_dtype)
+                    base_fp32 = tensor.to(device).to(torch.float32)
+                    a_fp32 = lora_a.to(device).to(torch.float32)
+                    b_fp32 = lora_b.to(device).to(torch.float32)
+                    delta = scale * (b_fp32 @ a_fp32)
+                    if bool(getattr(lora_config, "fan_in_fan_out", False) or getattr(lora_config, "lora_fan_in_fan_out", False)):
+                        delta = delta.T
+                    merged_tensors[key] = (base_fp32 + delta).to(original_dtype).detach().cpu()
                 else:
-                    merged_tensors[key] = tensor
+                    merged_tensors[key] = tensor.detach().cpu()
🧹 Nitpick comments (2)
src/axolotl/utils/lora_merge_efficient.py (2)

72-81: Skip opposite-format HF weight files when copying non-model files.

Prevents mixed-format outputs (e.g., copying stray .bin when merging safetensors).

Apply:

-        if filepath.name.startswith("model") and filepath.suffix == ".safetensors":
-            continue
+        # Skip any other HF weight files not selected as shards
+        if (
+            (filepath.name.startswith("model") and filepath.suffix == ".safetensors")
+            or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin")
+        ):
+            continue

205-209: Free host memory too.

Add gc.collect() after each shard; keep empty_cache as-is.

Apply:

+        import gc
         del merged_tensors
         if device != "cpu":
             torch.cuda.empty_cache()
+        gc.collect()
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 4b2fc64 and 1c1c3ab.

📒 Files selected for processing (1)
  • src/axolotl/utils/lora_merge_efficient.py (1 hunks)
🧰 Additional context used
🧠 Learnings (4)
📓 Common learnings
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.
📚 Learning: 2025-08-22T13:23:41.455Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/cli/merge_lora.py:65-81
Timestamp: 2025-08-22T13:23:41.455Z
Learning: The `lora_on_cpu` configuration in Axolotl is only relevant when loading the full model into memory (standard LoRA merge approach), not when processing individual shards in the memory-efficient approach.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses standard patterns `pytorch_model*.bin` and `model*.safetensors` for model shards, as defined in transformers/utils/__init__.py. Additional patterns like `pytorch_model*.safetensors` are not necessary for standard HF model discovery.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
📚 Learning: 2025-08-22T13:19:26.411Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3095
File: src/axolotl/utils/lora_merge_efficient.py:46-58
Timestamp: 2025-08-22T13:19:26.411Z
Learning: HuggingFace transformers uses these standard filename patterns: WEIGHTS_NAME = "pytorch_model.bin", SAFE_WEIGHTS_NAME = "model.safetensors" (not "pytorch_model.safetensors"), and sharded files follow "pytorch_model-*.bin" and "model-*.safetensors" patterns. The patterns "pytorch_model*.bin" and "model*.safetensors" are sufficient for discovering HF model shards.

Applied to files:

  • src/axolotl/utils/lora_merge_efficient.py
🧬 Code graph analysis (1)
src/axolotl/utils/lora_merge_efficient.py (2)
src/axolotl/utils/schemas/peft.py (1)
  • LoraConfig (28-191)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
🪛 Ruff (0.12.2)
src/axolotl/utils/lora_merge_efficient.py

110-110: Avoid specifying long messages outside the exception class

(TRY003)


121-123: Avoid specifying long messages outside the exception class

(TRY003)


139-139: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/utils/lora_merge_efficient.py (1)

86-93: Verify parity with legacy merge
Install torch and safetensors in your environment, then run the parity-check script on a small model. Confirm that the maximum absolute difference is 0.0 and there are no mismatched keys.

@ved1beta
Contributor Author

ved1beta commented Sep 7, 2025

Memory usage for both merge methods, measured with a simple test script (a sketch of one way to take these measurements follows the numbers below):

Memory-Efficient Method
• Peak GPU Memory: 300 MB
• Peak CPU Memory: 14.4 MB
• Execution Time: 12.0 seconds

Legacy Method
• Peak GPU Memory: 2,914 MB
• Peak CPU Memory: 14.4 MB
• Execution Time: 15.9 seconds
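
For reference, a minimal sketch of how such a comparison could be taken. This is not the script used above; profile_merge, the use of psutil, and the RSS-delta approximation of peak CPU memory are all assumptions.

import time

import psutil
import torch


def profile_merge(merge_fn, *args, **kwargs):
    """Run a merge callable and report peak GPU memory, CPU RSS growth, and wall time."""
    process = psutil.Process()
    rss_before = process.memory_info().rss

    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    start = time.perf_counter()
    merge_fn(*args, **kwargs)
    elapsed = time.perf_counter() - start

    peak_gpu_mb = (
        torch.cuda.max_memory_allocated() / 1024**2 if torch.cuda.is_available() else 0.0
    )
    rss_growth_mb = max(process.memory_info().rss - rss_before, 0) / 1024**2

    print(f"Peak GPU memory: {peak_gpu_mb:.0f} MB")
    print(f"CPU RSS growth: {rss_growth_mb:.1f} MB")
    print(f"Execution time: {elapsed:.1f} s")

RSS growth is only a rough proxy for peak host memory; tracemalloc or resource.getrusage would give a tighter number.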

Member

@djsaunde djsaunde left a comment

I saw you added a VRAM / time comparison, can you specify which model you used?

Comment on lines 25 to 27
merge_method = (
str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
)
Member

merge_method can only take the values Literal["legacy", "memory_efficient"], so you don't need this string handling.

merge_method = (
str(getattr(cfg, "merge_method", "")).strip().lower().replace("-", "_")
)
if merge_method in {"legacy", "standard"}:
Member

"standard" doesn't exist

Comment on lines 33 to 37
try:
_do_merge_lora_efficient(cfg=cfg)
except Exception: # pylint: disable=broad-exception-caught
LOG.exception("Memory-efficient merge failed; falling back to legacy.")
_do_merge_lora_legacy(cfg=cfg)
Member

tbh I'd rather have a hard failure here so we know if something is broken

Suggested change
try:
_do_merge_lora_efficient(cfg=cfg)
except Exception: # pylint: disable=broad-exception-caught
LOG.exception("Memory-efficient merge failed; falling back to legacy.")
_do_merge_lora_legacy(cfg=cfg)
_do_merge_lora_efficient(cfg=cfg)

Member

If there are unsupported combinations (you mentioned DoRA, RSLoRA), we should validate this in the pydantic model and raise an error there.
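
A minimal sketch of what such a check could look like, assuming Pydantic v2 and that merge_method lives on the same model as the DoRA/RSLoRA flags; the class and field names here are simplified stand-ins for the real schema:

from typing import Literal

from pydantic import BaseModel, model_validator


class LoraConfigSketch(BaseModel):
    # Simplified stand-in for the real schema; only the relevant fields are shown.
    merge_method: Literal["legacy", "memory_efficient"] = "memory_efficient"
    peft_use_dora: bool = False
    peft_use_rslora: bool = False

    @model_validator(mode="after")
    def check_merge_method_support(self):
        # The shard-wise merge only handles plain LoRA A/B pairs, so reject
        # adapter variants it cannot reproduce.
        if self.merge_method == "memory_efficient" and (
            self.peft_use_dora or self.peft_use_rslora
        ):
            raise ValueError(
                "merge_method='memory_efficient' does not support DoRA or RSLoRA; "
                "use merge_method='legacy' for these adapters."
            )
        return self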

Comment on lines 185 to 187
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True) # nosec B614
except TypeError:
lora_state = torch.load(lora_file, map_location="cpu") # nosec B614
Member

why do we need this try/except? can you choose one loading method and stick to it?
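
For context, if every supported torch version accepts weights_only (it has been available since roughly torch 1.13), the fallback could be dropped entirely; a sketch:

# Single code path, assuming all supported torch versions accept weights_only.
lora_state = torch.load(lora_file, map_location="cpu", weights_only=True)  # nosec B614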

output_path = Path(output_path)

if "/" in str(base_model_path) and not base_model_path.exists():
from huggingface_hub import snapshot_download
Member

Can be a toplevel import

Comment on lines +56 to +84
def copy_non_model_files(
input_path: Path, output_path: Path, model_shards: list[Path]
) -> None:
"""
Copy all non-model files to the output directory.

Args:
input_path: Source directory
output_path: Destination directory
model_shards: List of model shard files to skip
"""
LOG.info("Copying non-model files to output directory...")

shard_names = {shard.name for shard in model_shards}

for filepath in input_path.glob("*"):
if filepath.is_dir():
continue
if filepath.name in shard_names:
continue
if (
filepath.name.startswith("model") and filepath.suffix == ".safetensors"
) or (filepath.name.startswith("pytorch_model") and filepath.suffix == ".bin"):
continue
if filepath.suffix == ".gguf":
continue

LOG.debug(f"Copying {filepath.name} to output")
shutil.copy2(filepath, output_path)
Member

What is the purpose of this? Can you check that we / transformers don't already have a utility for this?

Contributor Author

We can remove this if we don't need files like config.json in the output directory; I'm not sure either way. I didn't find anything similar in transformers for separating out the non-model files.

Comment on lines 271 to 274
LOG.warning(
"safe_tensors=True requested but input shards are .bin; preserving .bin format "
"to avoid index mismatches."
)
Member

this is a bit confusing. if the user requests safe_tensors, shouldn't we convert them to safetensors?
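
One possible shape for such a conversion, sketched under the assumption that tied/shared tensors and the index-file rename (pytorch_model.bin.index.json to model.safetensors.index.json) are handled separately; the function name is made up:

from pathlib import Path

import torch
from safetensors.torch import save_file


def convert_bin_shard_to_safetensors(shard_path: Path, output_dir: Path) -> Path:
    """Write a merged .bin shard back out as .safetensors (sketch only)."""
    state_dict = torch.load(shard_path, map_location="cpu", weights_only=True)
    # safetensors requires contiguous tensors without shared storage.
    state_dict = {k: v.contiguous() for k, v in state_dict.items()}
    out_path = output_dir / shard_path.with_suffix(".safetensors").name
    save_file(state_dict, str(out_path))
    return out_path

Renaming shards also means rewriting the weight_map in the index file, which is probably why preserving the input format is the simpler default.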

Comment on lines 220 to 233
original_dtype = tensor.dtype
base_fp32 = tensor.to(device).to(torch.float32)
a_fp32 = lora_a.to(device).to(torch.float32)
b_fp32 = lora_b.to(device).to(torch.float32)
delta = scale * (b_fp32 @ a_fp32)
if bool(
lora_config_dict.get("fan_in_fan_out", False)
or lora_config_dict.get("lora_fan_in_fan_out", False)
):
delta = delta.T
merged_tensors[key] = (
(base_fp32 + delta).to(original_dtype).detach().cpu()
)
del base_fp32, a_fp32, b_fp32, delta
Member

this appears to be duplicated below, can be factored out into a helper method
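
A possible shape for the shared helper; the name and signature are suggestions rather than the PR's code:

import torch


def apply_lora_delta(
    tensor: torch.Tensor,
    lora_a: torch.Tensor,
    lora_b: torch.Tensor,
    scale: float,
    fan_in_fan_out: bool,
    device: str,
) -> torch.Tensor:
    """Merge one base weight with its LoRA A/B pair and return a CPU tensor."""
    original_dtype = tensor.dtype
    base_fp32 = tensor.to(device).to(torch.float32)
    delta = scale * (
        lora_b.to(device).to(torch.float32) @ lora_a.to(device).to(torch.float32)
    )
    if fan_in_fan_out:
        delta = delta.T
    return (base_fp32 + delta).to(original_dtype).detach().cpu()

Both the safetensors and .bin branches could then call this from their per-tensor loops.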

@ved1beta ved1beta requested a review from djsaunde October 3, 2025 19:11